import numpy as np
import torch as pt
from torch.nn.functional import one_hot

# Integer class labels for seven samples; the largest label present is 8.
x = pt.tensor([0, 2, 8, 1, 1, 2, 3])
print(x)

# Class count implied by the labels: one past the largest index.
# The result stays a 0-dim tensor rather than a plain Python int.
n_cls = pt.max(x) + 1
print(n_cls)

# Rebind x to its one-hot encoding. num_classes=-1 (the default) tells
# one_hot to infer the width as (largest label + 1), i.e. the value in n_cls.
x = one_hot(x, num_classes=-1)
print(x)
